{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyMDQAGwDnmThXqSHdNwTwvq",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/sugarforever/wtf-langchain/blob/main/09_Callbacks/09_Callbacks.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# LangChain Callback 示例"
      ],
      "metadata": {
        "id": "GIUvp6uXWAcn"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 准备环境"
      ],
      "metadata": {
        "id": "h4Ps0wdmWFGh"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "1. 安装langchain版本0.0.235，以及openai"
      ],
      "metadata": {
        "id": "E8OKcwR6VdHL"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "yDa7VdF9SlAm",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bf91585d-8b22-4bd9-d3ab-111d991cbd9b"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/1.3 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K     \u001b[91m━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.2/1.3 MB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K     \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m21.7 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m17.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m73.6/73.6 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m90.0/90.0 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.4/49.4 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ],
      "source": [
        "# Pin openai below 1.0: langchain 0.0.235 calls the pre-1.0 SDK interface,\n",
        "# so an unpinned upgrade to the latest openai release breaks the cells below.\n",
        "%pip install -q -U langchain==0.0.235 openai==0.27.8"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "2. 设置OPENAI API Key"
      ],
      "metadata": {
        "id": "N2-YYQa1s_jQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "from getpass import getpass\n",
        "\n",
        "# Ask for the key interactively instead of pasting it into the source, so the\n",
        "# secret is never saved with the notebook. A key that is already present in\n",
        "# the environment is left untouched.\n",
        "if not os.environ.get('OPENAI_API_KEY'):\n",
        "    os.environ['OPENAI_API_KEY'] = getpass(\"您的有效openai api key: \")"
      ],
      "metadata": {
        "id": "OSGQIm6FtD8A"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 示例代码"
      ],
      "metadata": {
        "id": "i0wXx4wfWIjF"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "1. 内置回调处理器 `StdOutCallbackHandler`"
      ],
      "metadata": {
        "id": "lEPqka5QWcph"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain.callbacks import StdOutCallbackHandler\n",
        "from langchain.chains import LLMChain\n",
        "from langchain.llms import OpenAI\n",
        "from langchain.prompts import PromptTemplate\n",
        "\n",
        "# Build the prompt and the model, then wire both into a chain whose events\n",
        "# are reported through the built-in stdout handler.\n",
        "prompt = PromptTemplate.from_template(\"Who is {name}?\")\n",
        "llm = OpenAI()\n",
        "handler = StdOutCallbackHandler()\n",
        "chain = LLMChain(\n",
        "    llm=llm,\n",
        "    prompt=prompt,\n",
        "    callbacks=[handler],\n",
        ")\n",
        "chain.run(name=\"Super Mario\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 202
        },
        "id": "jhoNdmhQs1J3",
        "outputId": "3e110c89-750d-4f64-c0b9-8bc49bb24895"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "\n",
            "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
            "Prompt after formatting:\n",
            "\u001b[32;1m\u001b[1;3mWho is Super Mario?\u001b[0m\n",
            "\n",
            "\u001b[1m> Finished chain.\u001b[0m\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'\\n\\nSuper Mario is the protagonist of the popular video game franchise of the same name created by Nintendo. He is a fictional character who stars in video games, television shows, comic books, and films. He is a plumber who is usually portrayed as a portly Italian-American, who is often accompanied by his brother Luigi. He is well known for his catchphrase \"It\\'s-a me, Mario!\"'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "2. 自定义回调处理器\n",
        "\n",
        "我们来实现一个处理器，统计每次 `LLM` 交互的处理时间。"
      ],
      "metadata": {
        "id": "NS9I0GYhs4Ny"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain.callbacks.base import BaseCallbackHandler\n",
        "import time\n",
        "\n",
        "class TimerHandler(BaseCallbackHandler):\n",
        "    \"\"\"Callback handler that records how long each chain / LLM run takes.\n",
        "\n",
        "    Attributes:\n",
        "        previous_ms: Start timestamp (ms) of the most recently started run,\n",
        "            or None before any run has started.\n",
        "        durations: Elapsed milliseconds of every finished run, in the order\n",
        "            the runs finished.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self) -> None:\n",
        "        super().__init__()\n",
        "        self.previous_ms = None\n",
        "        self.durations = []\n",
        "        # Start timestamps of runs still in flight, keyed by run id, so that a\n",
        "        # nested LLM run cannot overwrite the start time of its parent chain.\n",
        "        self._started_ms = {}\n",
        "\n",
        "    def current_ms(self):\n",
        "        \"\"\"Return a monotonic timestamp in whole milliseconds.\n",
        "\n",
        "        Only the difference between two readings is meaningful; the value is\n",
        "        unrelated to wall-clock time.\n",
        "        \"\"\"\n",
        "        # Use perf_counter on its own. Adding its fractional second to the\n",
        "        # wall clock, as before, skewed every duration by up to one second.\n",
        "        return int(time.perf_counter() * 1000)\n",
        "\n",
        "    def _start_timer(self, run_id) -> None:\n",
        "        \"\"\"Remember the start time of the run identified by ``run_id``.\"\"\"\n",
        "        self.previous_ms = self.current_ms()\n",
        "        self._started_ms[run_id] = self.previous_ms\n",
        "\n",
        "    def _stop_timer(self, run_id) -> None:\n",
        "        \"\"\"Record the elapsed time of ``run_id``; ignore runs never started.\"\"\"\n",
        "        started_ms = self._started_ms.pop(run_id, None)\n",
        "        # Compare against None explicitly: a start timestamp of 0 is valid.\n",
        "        if started_ms is not None:\n",
        "            self.durations.append(self.current_ms() - started_ms)\n",
        "\n",
        "    # LangChain passes a ``run_id`` keyword to every callback. If it is ever\n",
        "    # missing, all runs share the key None and only one can be timed at once.\n",
        "\n",
        "    def on_chain_start(self, serialized, inputs, **kwargs) -> None:\n",
        "        \"\"\"Start timing a chain run.\"\"\"\n",
        "        self._start_timer(kwargs.get(\"run_id\"))\n",
        "\n",
        "    def on_chain_end(self, outputs, **kwargs) -> None:\n",
        "        \"\"\"Stop timing a chain run and store its duration.\"\"\"\n",
        "        self._stop_timer(kwargs.get(\"run_id\"))\n",
        "\n",
        "    def on_llm_start(self, serialized, prompts, **kwargs) -> None:\n",
        "        \"\"\"Start timing an LLM run.\"\"\"\n",
        "        self._start_timer(kwargs.get(\"run_id\"))\n",
        "\n",
        "    def on_llm_end(self, response, **kwargs) -> None:\n",
        "        \"\"\"Stop timing an LLM run and store its duration.\"\"\"\n",
        "        self._stop_timer(kwargs.get(\"run_id\"))"
      ],
      "metadata": {
        "id": "WJQTUyrnwIf6"
      },
      "execution_count": 20,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Time two runs of the same chain with a single shared TimerHandler.\n",
        "timerHandler = TimerHandler()\n",
        "llm = OpenAI()\n",
        "prompt = PromptTemplate.from_template(\"What is the HEX code of color {color_name}?\")\n",
        "chain = LLMChain(llm=llm, prompt=prompt, callbacks=[timerHandler])\n",
        "\n",
        "for color_name in (\"blue\", \"purple\"):\n",
        "    response = chain.run(color_name=color_name)\n",
        "    print(response)\n",
        "\n",
        "# One entry per timed run, in milliseconds.\n",
        "timerHandler.durations"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "B7Y9E34VPObp",
        "outputId": "5d5f7945-8ad2-43ce-ce52-b6787ae61dac"
      },
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{'color_name': 'blue'}\n",
            "{'text': '\\n\\nThe HEX code of color blue is #0000FF.'}\n",
            "\n",
            "\n",
            "The HEX code of color blue is #0000FF.\n",
            "{'color_name': 'purple'}\n",
            "{'text': '\\n\\nThe HEX code of color purple is #800080.'}\n",
            "\n",
            "\n",
            "The HEX code of color purple is #800080.\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[1065, 1075]"
            ]
          },
          "metadata": {},
          "execution_count": 17
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "3. `Model` 与 `callbacks`\n",
        "\n",
        "`callbacks` 可以在构造函数中指定，也可以在执行期间的函数调用中指定。\n",
        "\n",
        "请参考如下代码："
      ],
      "metadata": {
        "id": "3U6j88miOHis"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Variant 1: the handler is supplied when the model is constructed.\n",
        "timerHandler = TimerHandler()\n",
        "llm = OpenAI(callbacks=[timerHandler])\n",
        "\n",
        "question = \"What is the HEX code of color BLACK?\"\n",
        "response = llm.predict(question)\n",
        "print(response)\n",
        "\n",
        "timerHandler.durations"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Gh8e7GWzOOkT",
        "outputId": "2af5da44-5a0f-40b0-8e55-1c0643c9c4a8"
      },
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "['What is the HEX code of color BLACK?']\n",
            "generations=[[Generation(text='\\n\\nThe hex code of black is #000000.', generation_info={'finish_reason': 'stop', 'logprobs': None})]] llm_output={'token_usage': {'prompt_tokens': 10, 'total_tokens': 21, 'completion_tokens': 11}, 'model_name': 'text-davinci-003'} run=None\n",
            "\n",
            "\n",
            "The hex code of black is #000000.\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[1223]"
            ]
          },
          "metadata": {},
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Variant 2: the handler is supplied on the individual call instead.\n",
        "timerHandler = TimerHandler()\n",
        "llm = OpenAI()\n",
        "\n",
        "question = \"What is the HEX code of color BLACK?\"\n",
        "response = llm.predict(question, callbacks=[timerHandler])\n",
        "print(response)\n",
        "\n",
        "timerHandler.durations"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Wm94hadJR1a2",
        "outputId": "00e4053b-0006-40f9-9c1a-d6ae1afef3c6"
      },
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "['What is the HEX code of color BLACK?']\n",
            "generations=[[Generation(text='\\n\\nThe Hex code of the color black is #000000.', generation_info={'finish_reason': 'stop', 'logprobs': None})]] llm_output={'token_usage': {'prompt_tokens': 10, 'total_tokens': 23, 'completion_tokens': 13}, 'model_name': 'text-davinci-003'} run=None\n",
            "\n",
            "\n",
            "The Hex code of the color black is #000000.\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[1777]"
            ]
          },
          "metadata": {},
          "execution_count": 25
        }
      ]
    }
  ]
}